{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "eda8d5ac",
   "metadata": {},
   "source": [
    "| [01_word_embedding/03_Word2Vec.ipynb](https://github.com/shibing624/nlp-tutorial/blob/main/01_word_embedding/03_Word2Vec.ipynb)  | 基于gensim使用word2vec模型  |[Open In Colab](https://colab.research.google.com/github/shibing624/nlp-tutorial/blob/main/01_word_embedding/03_Word2Vec.ipynb) |\n",
    "\n",
    "# Word2Vec\n",
    "\n",
    "本节分别使用gensim和PyTorch训练常用的Word2Vec模型。\n",
    "\n",
    "## Gensim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2da65b51",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gensim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "29f35b1e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentences = [['first', 'sentence'], ['second', 'sentence']]\n",
    "\n",
    "# 传入文本数据，直接初始化并训练Word2Vec模型\n",
    "model = gensim.models.Word2Vec(sentences, min_count=1)\n",
    "model.wv.key_to_index['first']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dbfff540",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.023671666"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 词之间的相似度\n",
    "model.wv.similarity('first', 'second')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0f3a774",
   "metadata": {},
   "source": [
    "### 例子1：gensim训练英文word2vec模型\n",
    "\n",
    "gensim的Word2Vec模型保存后可以加载并继续训练。下面的例子显式写出了常用参数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ebb96886",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[['human', 'interface', 'computer'], ['survey', 'user', 'computer', 'system', 'response', 'time'], ['eps', 'user', 'interface', 'system'], ['system', 'human', 'system', 'eps'], ['user', 'response', 'time'], ['trees'], ['graph', 'trees'], ['graph', 'minors', 'trees'], ['graph', 'minors', 'survey']]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<gensim.models.word2vec.Word2Vec at 0x7fd9d0c67b20>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from gensim.test.utils import common_texts\n",
    "from gensim.models import Word2Vec\n",
    "\n",
    "print(common_texts[:200])\n",
    "model = Word2Vec(sentences=common_texts, vector_size=100,\n",
    "                 window=5, min_count=1, workers=4)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6eaad73d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<gensim.models.word2vec.Word2Vec at 0x7fd9ce7681c0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.save(\"word2vec.model\")\n",
    "\n",
    "# 先保存，再继续接力训练\n",
    "model = Word2Vec.load(\"word2vec.model\")\n",
    "model.train([[\"hello\", \"world\"]], total_examples=1, epochs=1)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1df2df6c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.00515774, -0.00667028, -0.0077791 ,  0.00831315, -0.00198292,\n",
       "       -0.00685696, -0.0041556 ,  0.00514562, -0.00286997, -0.00375075,\n",
       "        0.0016219 , -0.0027771 , -0.00158482,  0.0010748 , -0.00297881,\n",
       "        0.00852176,  0.00391207, -0.00996176,  0.00626142, -0.00675622,\n",
       "        0.00076966,  0.00440552, -0.00510486, -0.00211128,  0.00809783,\n",
       "       -0.00424503, -0.00763848,  0.00926061, -0.00215612, -0.00472081,\n",
       "        0.00857329,  0.00428458,  0.0043261 ,  0.00928722, -0.00845554,\n",
       "        0.00525685,  0.00203994,  0.0041895 ,  0.00169839,  0.00446543,\n",
       "        0.00448759,  0.0061063 , -0.00320303, -0.00457706, -0.00042664,\n",
       "        0.00253447, -0.00326412,  0.00605948,  0.00415534,  0.00776685,\n",
       "        0.00257002,  0.00811904, -0.00138761,  0.00808028,  0.0037181 ,\n",
       "       -0.00804967, -0.00393476, -0.0024726 ,  0.00489447, -0.00087241,\n",
       "       -0.00283173,  0.00783599,  0.00932561, -0.0016154 , -0.00516075,\n",
       "       -0.00470313, -0.00484746, -0.00960562,  0.00137242, -0.00422615,\n",
       "        0.00252744,  0.00561612, -0.00406709, -0.00959937,  0.00154715,\n",
       "       -0.00670207,  0.0024959 , -0.00378173,  0.00708048,  0.00064041,\n",
       "        0.00356198, -0.00273993, -0.00171105,  0.00765502,  0.00140809,\n",
       "       -0.00585215, -0.00783678,  0.00123304,  0.00645651,  0.00555797,\n",
       "       -0.00897966,  0.00859466,  0.00404815,  0.00747178,  0.00974917,\n",
       "       -0.0072917 , -0.00904259,  0.0058377 ,  0.00939395,  0.00350795],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vector1 = model.wv['computer']  # get numpy vector of a word\n",
    "vector1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3362b1cf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('system', 0.21617142856121063),\n",
       " ('survey', 0.044689200818538666),\n",
       " ('interface', 0.01520337350666523),\n",
       " ('time', 0.0019510575802996755),\n",
       " ('trees', -0.03284314647316933),\n",
       " ('human', -0.0742427185177803),\n",
       " ('response', -0.09317588806152344),\n",
       " ('graph', -0.09575346857309341),\n",
       " ('eps', -0.10513805598020554),\n",
       " ('user', -0.16911622881889343)]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sims = model.wv.most_similar('computer', topn=10)  # get other similar words\n",
    "sims"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e0ca7a2",
   "metadata": {},
   "source": [
    "也可以只保存训练好的词向量（词到向量的映射），再通过 `KeyedVectors` 快速加载，直接查询词的向量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5c043fbc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.00515774, -0.00667028, -0.0077791 ,  0.00831315, -0.00198292,\n",
       "       -0.00685696, -0.0041556 ,  0.00514562, -0.00286997, -0.00375075,\n",
       "        0.0016219 , -0.0027771 , -0.00158482,  0.0010748 , -0.00297881,\n",
       "        0.00852176,  0.00391207, -0.00996176,  0.00626142, -0.00675622,\n",
       "        0.00076966,  0.00440552, -0.00510486, -0.00211128,  0.00809783,\n",
       "       -0.00424503, -0.00763848,  0.00926061, -0.00215612, -0.00472081,\n",
       "        0.00857329,  0.00428458,  0.0043261 ,  0.00928722, -0.00845554,\n",
       "        0.00525685,  0.00203994,  0.0041895 ,  0.00169839,  0.00446543,\n",
       "        0.00448759,  0.0061063 , -0.00320303, -0.00457706, -0.00042664,\n",
       "        0.00253447, -0.00326412,  0.00605948,  0.00415534,  0.00776685,\n",
       "        0.00257002,  0.00811904, -0.00138761,  0.00808028,  0.0037181 ,\n",
       "       -0.00804967, -0.00393476, -0.0024726 ,  0.00489447, -0.00087241,\n",
       "       -0.00283173,  0.00783599,  0.00932561, -0.0016154 , -0.00516075,\n",
       "       -0.00470313, -0.00484746, -0.00960562,  0.00137242, -0.00422615,\n",
       "        0.00252744,  0.00561612, -0.00406709, -0.00959937,  0.00154715,\n",
       "       -0.00670207,  0.0024959 , -0.00378173,  0.00708048,  0.00064041,\n",
       "        0.00356198, -0.00273993, -0.00171105,  0.00765502,  0.00140809,\n",
       "       -0.00585215, -0.00783678,  0.00123304,  0.00645651,  0.00555797,\n",
       "       -0.00897966,  0.00859466,  0.00404815,  0.00747178,  0.00974917,\n",
       "       -0.0072917 , -0.00904259,  0.0058377 ,  0.00939395,  0.00350795],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from gensim.models import KeyedVectors\n",
    "# Store just the words + their trained embeddings.\n",
    "word_vectors = model.wv\n",
    "word_vectors.save(\"word2vec.wordvectors\")\n",
    "# Load back with memory-mapping = read-only, shared across processes.\n",
    "wv = KeyedVectors.load(\"word2vec.wordvectors\", mmap='r')\n",
    "vector2 = wv['computer']  # Get numpy vector of a word\n",
    "vector2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b19d464d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compare = vector1 == vector2\n",
    "compare.all()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ff2338f",
   "metadata": {},
   "source": [
    "向量结果是一样的。\n",
    "\n",
    "### 例子2：gensim训练中文word2vec模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ed0f63a3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['1,本报记者',\n",
       "  '发自',\n",
       "  '上海',\n",
       "  '国外',\n",
       "  '媒体',\n",
       "  '昨日',\n",
       "  '报道',\n",
       "  '澳大利亚',\n",
       "  '银行',\n",
       "  'acq',\n",
       "  'arie',\n",
       "  '预计',\n",
       "  '推出',\n",
       "  '中国',\n",
       "  '人民币',\n",
       "  '10',\n",
       "  '亿元',\n",
       "  '商业',\n",
       "  '住房',\n",
       "  '抵押',\n",
       "  '贷款',\n",
       "  '资产',\n",
       "  '证券化',\n",
       "  '计划',\n",
       "  '有关部门',\n",
       "  '批准',\n",
       "  '将是',\n",
       "  '海外',\n",
       "  '资金',\n",
       "  '首次',\n",
       "  '此项',\n",
       "  '计划',\n",
       "  '市场分析',\n",
       "  '人士',\n",
       "  '计划',\n",
       "  '预计',\n",
       "  '中国',\n",
       "  '监管部门',\n",
       "  '阻力',\n",
       "  '考虑到',\n",
       "  '交易',\n",
       "  '相关',\n",
       "  '高昂',\n",
       "  '固定成本',\n",
       "  '人民币',\n",
       "  '10',\n",
       "  '亿元',\n",
       "  '可能是',\n",
       "  '最低',\n",
       "  '金额',\n",
       "  '银行',\n",
       "  '原本',\n",
       "  '计划',\n",
       "  '2006',\n",
       "  '年初',\n",
       "  '中国',\n",
       "  '推出',\n",
       "  'macquarie',\n",
       "  'anda',\n",
       "  '房地产',\n",
       "  '投资信托',\n",
       "  '计划',\n",
       "  '香港特区',\n",
       "  '证监会',\n",
       "  '否决',\n",
       "  '该银行',\n",
       "  '中国',\n",
       "  '房地产投资',\n",
       "  '基金',\n",
       "  '首席',\n",
       "  '投资',\n",
       "  '执行官',\n",
       "  '此前',\n",
       "  '开发商',\n",
       "  '行列',\n",
       "  '竟是',\n",
       "  '金融机构',\n",
       "  '项目',\n",
       "  '投融资',\n",
       "  '资本运作',\n",
       "  '才是',\n",
       "  '特长'],\n",
       " ['2,复旦',\n",
       "  '新浪',\n",
       "  '本报记者',\n",
       "  '杨国强',\n",
       "  '1984年',\n",
       "  '相貌端正',\n",
       "  '复旦大学',\n",
       "  '新闻系',\n",
       "  '大学',\n",
       "  '同学',\n",
       "  '回忆说',\n",
       "  '内向',\n",
       "  '做事',\n",
       "  '很有',\n",
       "  '生活',\n",
       "  '学习',\n",
       "  '很有',\n",
       "  '计划性',\n",
       "  '大学毕业',\n",
       "  '上海',\n",
       "  '电视台',\n",
       "  '当了',\n",
       "  '两年',\n",
       "  '记者',\n",
       "  '赴美',\n",
       "  '求学',\n",
       "  '先在',\n",
       "  '奥克拉荷',\n",
       "  '大学',\n",
       "  '拿了',\n",
       "  '新闻学',\n",
       "  '硕士',\n",
       "  '再到',\n",
       "  '德州',\n",
       "  '奥斯汀',\n",
       "  '大学',\n",
       "  '拿了',\n",
       "  '财务',\n",
       "  '专业',\n",
       "  '硕士',\n",
       "  '转入',\n",
       "  '企业界',\n",
       "  '早就',\n",
       "  '美国',\n",
       "  '会计师',\n",
       "  '协会',\n",
       "  '美国',\n",
       "  '注册会计师',\n",
       "  '1993',\n",
       "  '1999',\n",
       "  '普华永道',\n",
       "  '工作',\n",
       "  '负责',\n",
       "  '硅谷',\n",
       "  '地区',\n",
       "  '高科技公司',\n",
       "  '提供',\n",
       "  '审计',\n",
       "  '服务',\n",
       "  '商业',\n",
       "  '咨询',\n",
       "  '在此期间',\n",
       "  '参与',\n",
       "  '多家',\n",
       "  '高科技公司',\n",
       "  '上市',\n",
       "  '1999',\n",
       "  '2000',\n",
       "  '财务',\n",
       "  '副总裁',\n",
       "  '身份',\n",
       "  '加盟',\n",
       "  '新浪',\n",
       "  '运作',\n",
       "  '新浪',\n",
       "  '美国',\n",
       "  '上市',\n",
       "  '参与',\n",
       "  '设计',\n",
       "  '中国',\n",
       "  '互联网',\n",
       "  '公司',\n",
       "  '海外',\n",
       "  '上市',\n",
       "  '结构',\n",
       "  '新浪',\n",
       "  '余家',\n",
       "  '中国概念股',\n",
       "  '上市',\n",
       "  '提供',\n",
       "  '借鉴',\n",
       "  '2001年',\n",
       "  '担任',\n",
       "  '新浪',\n",
       "  'cfo',\n",
       "  '2000',\n",
       "  '2001',\n",
       "  '推动',\n",
       "  '新浪',\n",
       "  '变了',\n",
       "  '照搬',\n",
       "  '美国',\n",
       "  '网络广告',\n",
       "  '销售',\n",
       "  '方式',\n",
       "  '改为',\n",
       "  '符合',\n",
       "  '中国',\n",
       "  '广告主',\n",
       "  '需求',\n",
       "  '时段',\n",
       "  '流量',\n",
       "  '模式',\n",
       "  '广告',\n",
       "  '主和',\n",
       "  '客户',\n",
       "  '肯定',\n",
       "  '这一',\n",
       "  '举措',\n",
       "  '新浪',\n",
       "  '互联网',\n",
       "  '广告',\n",
       "  '市场',\n",
       "  '领先地位',\n",
       "  '奠定',\n",
       "  '基础',\n",
       "  '2003年',\n",
       "  '主持',\n",
       "  '谈判',\n",
       "  '两次',\n",
       "  '并购',\n",
       "  '新浪',\n",
       "  '无线',\n",
       "  '市场',\n",
       "  '后来居上',\n",
       "  '稳定的',\n",
       "  '利润',\n",
       "  '2004年',\n",
       "  '6月',\n",
       "  '兼任',\n",
       "  '新浪',\n",
       "  '联席',\n",
       "  '营长',\n",
       "  '负责',\n",
       "  '网站',\n",
       "  '运营',\n",
       "  '广告',\n",
       "  '销售',\n",
       "  '市场',\n",
       "  '广告',\n",
       "  '销售',\n",
       "  '部门',\n",
       "  '重组',\n",
       "  '进了',\n",
       "  '系统化',\n",
       "  '销售',\n",
       "  '管理体系',\n",
       "  '新浪',\n",
       "  '2005年',\n",
       "  '广告',\n",
       "  '销售',\n",
       "  '业绩',\n",
       "  '增长率',\n",
       "  '年来',\n",
       "  '首次',\n",
       "  '超过',\n",
       "  '竞争对手',\n",
       "  '推动',\n",
       "  '博客',\n",
       "  '发展计划',\n",
       "  '赢得了',\n",
       "  '新浪博客',\n",
       "  '成功',\n",
       "  '2005',\n",
       "  '年度',\n",
       "  '中国',\n",
       "  '杰出',\n",
       "  'cfo',\n",
       "  '2005',\n",
       "  '年度',\n",
       "  '中国',\n",
       "  '广告',\n",
       "  '影响力',\n",
       "  '人物',\n",
       "  '荣誉',\n",
       "  '2005年',\n",
       "  '9月',\n",
       "  '升任',\n",
       "  '新浪',\n",
       "  '裁并',\n",
       "  '兼任',\n",
       "  '首席',\n",
       "  '财务',\n",
       "  '2006年',\n",
       "  '5月',\n",
       "  '10日',\n",
       "  '担任',\n",
       "  '新浪',\n",
       "  'ceo']]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "txt_path = 'data/C000008_test.txt'\n",
    "sentences = [i.split() for i in open(txt_path, 'r', encoding='utf-8').read().split('\\n')]\n",
    "sentences[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "792bf8dc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'新浪': 0,\n",
       " '中国': 1,\n",
       " '化妆品': 2,\n",
       " '美国': 3,\n",
       " '广告': 4,\n",
       " '碎片': 5,\n",
       " '计划': 6,\n",
       " '地球': 7,\n",
       " '销售': 8,\n",
       " '上市': 9,\n",
       " '5月': 10,\n",
       " '科学家': 11,\n",
       " '皮肤': 12,\n",
       " '改善': 13,\n",
       " '大学': 14,\n",
       " '的产品': 15,\n",
       " '市场': 16,\n",
       " '彗星': 17,\n",
       " '财务': 18,\n",
       " '预计': 19,\n",
       " '参与': 20,\n",
       " '网站': 21,\n",
       " '兼任': 22,\n",
       " '提供': 23,\n",
       " '瓦斯': 24,\n",
       " '高科技公司': 25,\n",
       " '73p': 26,\n",
       " '负责': 27,\n",
       " '植物': 28,\n",
       " '拿了': 29,\n",
       " '2005年': 30,\n",
       " '1999': 31,\n",
       " '距离': 32,\n",
       " '硕士': 33,\n",
       " '首席': 34,\n",
       " '2000': 35,\n",
       " '年度': 36,\n",
       " '10': 37,\n",
       " '报道': 38,\n",
       " '银行': 39,\n",
       " '担任': 40,\n",
       " '推出': 41,\n",
       " '便宜': 42,\n",
       " '之间': 43,\n",
       " '人民币': 44,\n",
       " '2005': 45,\n",
       " '亿元': 46,\n",
       " '商业': 47,\n",
       " '推动': 48,\n",
       " 'cfo': 49,\n",
       " '12日': 50,\n",
       " '宇航局': 51,\n",
       " '互联网': 52,\n",
       " '海外': 53,\n",
       " '上海': 54,\n",
       " '首次': 55,\n",
       " '很有': 56,\n",
       " '3号': 57,\n",
       " '德州': 58,\n",
       " '奥斯汀': 59,\n",
       " '再到': 60,\n",
       " '赴美': 61,\n",
       " '新闻学': 62,\n",
       " '奥克拉荷': 63,\n",
       " '电视台': 64,\n",
       " '先在': 65,\n",
       " '求学': 66,\n",
       " '当了': 67,\n",
       " '两年': 68,\n",
       " '记者': 69,\n",
       " '那一': 70,\n",
       " '专业': 71,\n",
       " '副总裁': 72,\n",
       " '照搬': 73,\n",
       " '变了': 74,\n",
       " '2001': 75,\n",
       " '2001年': 76,\n",
       " '借鉴': 77,\n",
       " '中国概念股': 78,\n",
       " '余家': 79,\n",
       " '结构': 80,\n",
       " '公司': 81,\n",
       " '设计': 82,\n",
       " '运作': 83,\n",
       " '加盟': 84,\n",
       " '身份': 85,\n",
       " '多家': 86,\n",
       " '转入': 87,\n",
       " '在此期间': 88,\n",
       " '咨询': 89,\n",
       " '服务': 90,\n",
       " '审计': 91,\n",
       " '地区': 92,\n",
       " '工作': 93,\n",
       " '普华永道': 94,\n",
       " '1993': 95,\n",
       " '注册会计师': 96,\n",
       " '协会': 97,\n",
       " '会计师': 98,\n",
       " '早就': 99,\n",
       " '企业界': 100,\n",
       " '硅谷': 101,\n",
       " '同学': 102,\n",
       " '大学毕业': 103,\n",
       " '相关': 104,\n",
       " '市场分析': 105,\n",
       " '人士': 106,\n",
       " '监管部门': 107,\n",
       " '阻力': 108,\n",
       " '考虑到': 109,\n",
       " '交易': 110,\n",
       " '高昂': 111,\n",
       " 'macquarie': 112,\n",
       " '固定成本': 113,\n",
       " '可能是': 114,\n",
       " '最低': 115,\n",
       " '金额': 116,\n",
       " '原本': 117,\n",
       " '2006': 118,\n",
       " '此项': 119,\n",
       " '资金': 120,\n",
       " '将是': 121,\n",
       " '批准': 122,\n",
       " '有关部门': 123,\n",
       " '证券化': 124,\n",
       " '资产': 125,\n",
       " '贷款': 126,\n",
       " '抵押': 127,\n",
       " '住房': 128,\n",
       " 'arie': 129,\n",
       " 'acq': 130,\n",
       " '澳大利亚': 131,\n",
       " '昨日': 132,\n",
       " '媒体': 133,\n",
       " '国外': 134,\n",
       " '发自': 135,\n",
       " '年初': 136,\n",
       " 'anda': 137,\n",
       " '计划性': 138,\n",
       " '相貌端正': 139,\n",
       " '才是': 140,\n",
       " '特长': 141,\n",
       " '2,复旦': 142,\n",
       " '本报记者': 143,\n",
       " '杨国强': 144,\n",
       " '1984年': 145,\n",
       " '复旦大学': 146,\n",
       " '房地产': 147,\n",
       " '新闻系': 148,\n",
       " '回忆说': 149,\n",
       " '内向': 150,\n",
       " '做事': 151,\n",
       " '生活': 152,\n",
       " '学习': 153,\n",
       " '资本运作': 154,\n",
       " '投融资': 155,\n",
       " '项目': 156,\n",
       " '金融机构': 157,\n",
       " '竟是': 158,\n",
       " '行列': 159,\n",
       " '开发商': 160,\n",
       " '此前': 161,\n",
       " '执行官': 162,\n",
       " '投资': 163,\n",
       " '基金': 164,\n",
       " '房地产投资': 165,\n",
       " '该银行': 166,\n",
       " '否决': 167,\n",
       " '证监会': 168,\n",
       " '香港特区': 169,\n",
       " '投资信托': 170,\n",
       " '网络广告': 171,\n",
       " '这一': 172,\n",
       " '方式': 173,\n",
       " '很可能': 174,\n",
       " '角质化': 175,\n",
       " '过程': 176,\n",
       " '所需': 177,\n",
       " '时间': 178,\n",
       " '三个月': 179,\n",
       " '会把': 180,\n",
       " '理想': 181,\n",
       " '安全地': 182,\n",
       " '预期': 183,\n",
       " '短期': 184,\n",
       " '都是': 185,\n",
       " '加了': 186,\n",
       " '违禁': 187,\n",
       " '原料': 188,\n",
       " '虽然在': 189,\n",
       " '表皮': 190,\n",
       " '状况': 191,\n",
       " '3,化妆品': 192,\n",
       " '利用': 193,\n",
       " '月球': 194,\n",
       " '20': 195,\n",
       " '多倍': 196,\n",
       " '不会有': 197,\n",
       " '危险': 198,\n",
       " '提醒': 199,\n",
       " '会对': 200,\n",
       " 'n101': 201,\n",
       " '观察': 202,\n",
       " '中最': 203,\n",
       " '明亮': 204,\n",
       " '双筒望远镜': 205,\n",
       " '肉眼': 206,\n",
       " '观察到': 207,\n",
       " '天内': 208,\n",
       " '导致': 209,\n",
       " '最接近': 210,\n",
       " '皮肤病': 211,\n",
       " '成分': 212,\n",
       " '也许': 213,\n",
       " '发现': 214,\n",
       " '相差无几': 215,\n",
       " '配方': 216,\n",
       " '选购': 217,\n",
       " '简单': 218,\n",
       " '办法': 219,\n",
       " '尝试': 220,\n",
       " '检测': 221,\n",
       " '合格': 222,\n",
       " '品牌': 223,\n",
       " '选择': 224,\n",
       " '不良反应': 225,\n",
       " '感觉': 226,\n",
       " '对照': 227,\n",
       " '越好': 228,\n",
       " '质量': 229,\n",
       " '出售': 230,\n",
       " '2.': 231,\n",
       " '绿色': 232,\n",
       " '作成': 233,\n",
       " '形态': 234,\n",
       " '装在': 235,\n",
       " '瓶子': 236,\n",
       " '不可能': 237,\n",
       " '3.': 238,\n",
       " '不含': 239,\n",
       " '防腐剂': 240,\n",
       " '化学成分': 241,\n",
       " '迷信': 242,\n",
       " '纯天然': 243,\n",
       " '宣传': 244,\n",
       " '轨道': 245,\n",
       " '即便是': 246,\n",
       " '改为': 247,\n",
       " '增长率': 248,\n",
       " '并购': 249,\n",
       " '无线': 250,\n",
       " '后来居上': 251,\n",
       " '稳定的': 252,\n",
       " '利润': 253,\n",
       " '2004年': 254,\n",
       " '6月': 255,\n",
       " '联席': 256,\n",
       " '营长': 257,\n",
       " '运营': 258,\n",
       " '部门': 259,\n",
       " '重组': 260,\n",
       " '进了': 261,\n",
       " '系统化': 262,\n",
       " '管理体系': 263,\n",
       " '两次': 264,\n",
       " '谈判': 265,\n",
       " '主持': 266,\n",
       " '主和': 267,\n",
       " '符合': 268,\n",
       " '广告主': 269,\n",
       " '需求': 270,\n",
       " '时段': 271,\n",
       " '流量': 272,\n",
       " '模式': 273,\n",
       " '客户': 274,\n",
       " '2003年': 275,\n",
       " '肯定': 276,\n",
       " '最舒服': 277,\n",
       " '举措': 278,\n",
       " '领先地位': 279,\n",
       " '奠定': 280,\n",
       " '基础': 281,\n",
       " '业绩': 282,\n",
       " '年来': 283,\n",
       " '28日': 284,\n",
       " '超过': 285,\n",
       " '4月': 286,\n",
       " '27日': 287,\n",
       " '14日': 288,\n",
       " '30': 289,\n",
       " '史无前例': 290,\n",
       " '对此': 291,\n",
       " '反驳': 292,\n",
       " '撞击': 293,\n",
       " '更不': 294,\n",
       " '会引起': 295,\n",
       " '大规模': 296,\n",
       " '海啸': 297,\n",
       " '生物': 298,\n",
       " '灭绝': 299,\n",
       " '灾难': 300,\n",
       " '太空': 301,\n",
       " '2,美国': 302,\n",
       " 'ceo': 303,\n",
       " '杰出': 304,\n",
       " '竞争对手': 305,\n",
       " '博客': 306,\n",
       " '发展计划': 307,\n",
       " '赢得了': 308,\n",
       " '新浪博客': 309,\n",
       " '成功': 310,\n",
       " '影响力': 311,\n",
       " '10日': 312,\n",
       " '人物': 313,\n",
       " '荣誉': 314,\n",
       " '9月': 315,\n",
       " '升任': 316,\n",
       " '裁并': 317,\n",
       " '2006年': 318,\n",
       " '1,本报记者': 319}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = gensim.models.Word2Vec(\n",
    "    sentences, vector_size=50, window=5, min_count=1, workers=4)\n",
    "model.save('C000008.word2vec.model')\n",
    "model.wv.key_to_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ee46d520",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "131\n",
      "[-0.01635404  0.00892374 -0.00824884  0.00170714  0.01705568 -0.00899647\n",
      "  0.00914393 -0.01339554 -0.00736877  0.01885619 -0.00314003  0.0005598\n",
      " -0.00817219 -0.01527859 -0.0030647   0.00505625 -0.00163121  0.01093926\n",
      " -0.00567137  0.00446063  0.01084758  0.01680043 -0.00283044 -0.018494\n",
      "  0.00881959  0.00127615  0.01478991 -0.00165017 -0.00545879 -0.01751602\n",
      " -0.00163091  0.00587183  0.01071152  0.01424571 -0.01153169  0.00400406\n",
      "  0.01238476 -0.00952362 -0.00622743  0.01345337  0.00357411  0.00042622\n",
      "  0.00689636  0.0005763   0.01942016  0.0100719  -0.01777284 -0.01420312\n",
      "  0.00203524  0.01278464]\n",
      "[-0.00724434  0.00724444 -0.00516048 -0.00945492 -0.00724321 -0.01235146\n",
      "  0.00705467 -0.01544216 -0.01511148 -0.00901232  0.01482126  0.00356372\n",
      "  0.01094765  0.01646764 -0.01249899 -0.00918809  0.01543008  0.00986083\n",
      "  0.00925157  0.01797456  0.01754135 -0.00537365  0.00210229  0.01097721\n",
      "  0.01779929  0.01981683 -0.01665249 -0.0102626   0.00993185  0.00185638\n",
      "  0.00029913  0.01508015  0.01280946 -0.01445732 -0.00522017 -0.0169047\n",
      "  0.01423509  0.00552213 -0.00447353  0.00368358  0.0154058   0.01627867\n",
      "  0.01233664  0.00638969  0.00320422  0.01628878  0.00260389  0.00081539\n",
      " -0.01899525  0.00505414]\n",
      "0.07167525\n"
     ]
    }
   ],
   "source": [
    "# key index\n",
    "print(model.wv.key_to_index['中国'])\n",
    "print(model.wv.key_to_index['澳大利亚'])\n",
    "\n",
    "# word vector\n",
    "print(model.wv['中国'])\n",
    "print(model.wv['澳大利亚'])\n",
    "\n",
    "# compare two word\n",
    "print(model.wv.similarity('中国', '澳大利亚'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0c42076",
   "metadata": {},
   "source": [
    "## PyTorch\n",
    "\n",
    "下面演示使用PyTorch训练skip-gram结构的Word2Vec模型，实现比上一节的论文版本更简化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8ce19025",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "\n",
    "\n",
    "def random_batch():\n",
    "    random_inputs = []\n",
    "    random_labels = []\n",
    "    random_index = np.random.choice(\n",
    "        range(len(skip_grams)), batch_size, replace=False)\n",
    "\n",
    "    for i in random_index:\n",
    "        random_inputs.append(np.eye(voc_size)[skip_grams[i][0]])  # target\n",
    "        random_labels.append(skip_grams[i][1])  # context word\n",
    "\n",
    "    return random_inputs, random_labels\n",
    "\n",
    "\n",
    "class Word2Vec(nn.Module):\n",
    "    # Model\n",
    "    def __init__(self):\n",
    "        super(Word2Vec, self).__init__()\n",
    "        # W and WT is not Traspose relationship\n",
    "        # voc_size > embedding_size Weight\n",
    "        self.W = nn.Linear(voc_size, embedding_size, bias=False)\n",
    "        # embedding_size > voc_size Weight\n",
    "        self.WT = nn.Linear(embedding_size, voc_size, bias=False)\n",
    "\n",
    "    def forward(self, X):\n",
    "        # X : [batch_size, voc_size]\n",
    "        hidden_layer = self.W(X)  # hidden_layer : [batch_size, embedding_size]\n",
    "        # output_layer : [batch_size, voc_size]\n",
    "        output_layer = self.WT(hidden_layer)\n",
    "        return output_layer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e34043a",
   "metadata": {},
   "source": [
    "定义参数，开始训练："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "858fdfb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1000 cost = 1.119079\n",
      "Epoch: 2000 cost = 1.241441\n",
      "Epoch: 3000 cost = 1.609436\n",
      "Epoch: 4000 cost = 1.662706\n",
      "Epoch: 5000 cost = 1.212765\n",
      "Epoch: 6000 cost = 0.920481\n",
      "Epoch: 7000 cost = 1.035215\n",
      "Epoch: 8000 cost = 1.205380\n",
      "Epoch: 9000 cost = 1.611376\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD6CAYAAACs/ECRAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAgR0lEQVR4nO3de3xU9Z3/8deHcA8QsIACUsF9UJGEBEJAFJEorKWIIFar1l1Yf22zUNnidvVR9+FDi+26ta39oay6yFYFXC1afyLgpSoo5dJQCRDuIIJxYaEQEEK4CWM+vz8ySblMLsNMZpKc9/PxmEfOfOd7zvczBzLvnMucY+6OiIgEV5NkFyAiIsmlIBARCTgFgYhIwCkIREQCTkEgIhJwCgIRkYBTEEi9ZGYTzWx8nJZVZGYd47EskcbI6vP3CDp27Og9evRIdhnSwG3YsIErr7ySpk2bJrsUkTq3evXqA+7eKZp56vVvRo8ePSgoKEh2GRInt9xyC7t27eLkyZNMmTKFvLw82rRpw5QpU3jrrbdo1aoV8+fP5+KLL2bq1Km0adOG+++/n9zcXPr378/q1aspLi5mzpw5/OIXv2DDhg3ccccd/Nu//VuVy4fy/0eLFi2iY0dtFEjjZ2afRzuPdg1JwrzwwgusXr2agoICpk+fzsGDBzl27BiDBw9m3bp1XHfddfzXf/1XxHmbN2/O0qVLmThxImPHjuWZZ55h48aNzJo1i4MHD1a5fBGpmYJAEmb69OlkZWUxePBgdu3axfbt22nevDmjR48GYMCAARQVFUWcd8yYMQD07duX9PR0unTpQosWLbj88svZtWtXlcsXkZrV611D0ngsWbKERYsWkZ+fT+vWrcnNzeXkyZM0a9YMMwMgJSWFUCgUcf4WLVoA0KRJk8rpiuehUKjK5YtIzbRFIAlRUlJChw4daN26NVu3bmXlypUNavkijZmCQBJi5MiRhEIhMjMzefjhhxk8eHCDWr5IY1avTx/NyclxnTUk0Xhz7f/y6/e2sefwCbq2b8UD37yCW/p3S3ZZIgljZqvdPSeaeXSMQBqNN9f+L//6xgZOnP4KgP89fIJ/fWMDgMJApBraNSSNxq/f21YZAhVOnP6KX7+3LUkViTQMCgJpNPYcPhFVu4iUUxBIo9G1fauo2kWknIJAGo0HvnkFrZqlnNXWqlkKD3zziiRVJNIw6GCxNBoVB4R11pBIdBQE0qjc0r+bPvhFoqRdQyIiAacgEBEJOAWBiEjAKQhERAJOQSAiEnBxCQIze8HM9pvZxipeNzObbmafmtl6M8uOx7giIhK7eG0RzAJGVvP6t4Be4Uce8J9xGldERGIUlyBw96XAF9V0GQvM8XIrgfZm1iUeY4uISGwSdYygG7DrjOe7w23nMbM8Mysws4Li4uKEFCciEmSJCgKL0BbxjjjuPtPdc9w9p1OnTnVcloiIJCoIdgPdz3h+KbAnQWOLiEg1EhUEC4Dx4bOHBgMl7r43QWOLiEg14nLROTP7HZALdDSz3cBPgWYA7j4DeAcYBXwKHAfuice4IiISu7gEgbvfVcPrDtwbj7FERCS+9M1iEZGAUxCIiAScgkBEJOAUBCIiAacgEBEJOAWBNBpTp07liSeeSHYZIg2OgkBEJOAUBNKgPfbYY1xxxRWMGDGCbdu2AVBYWMjgwYPJzMxk3LhxHDp0CIBVq1aRmZnJ1VdfzQMPPEBGRkYySxepNxQE0mCtXr2auXPnsnbtWt544w1WrVoFwPjx4/nlL3/J+vXr6du3L48++igA99xzDzNmzCA/P5+UlJRkli5SrygIpMFatmwZ48aNo3Xr1rRr144xY8Zw7NgxDh8+zLBhwwCYMGECS5cu5fDhw5SWlnLNNdcA8N3vfjeZpYvUKwoCadDMIl3h/HzlVzkRkUgUBNJgXXfddcybN48TJ05QWlrKwoULSU1NpUOHDixbtgyAl156iWHDhtGhQwfatm3L
ypUrAZg7d24ySxepV+Jy0TmRZMjOzuaOO+6gX79+XHbZZQwdOhSA2bNnM3HiRI4fP87ll1/Oiy++CMDzzz/PD37wA1JTU8nNzSUtLS2Z5YvUG1afN5lzcnK8oKAg2WVIY7D+NY6+81PanNwLaZfyeFEme+nEU089lezKROLKzFa7e04082iLQBq/9a/Bwh/x9toj/GL5l4TKtnBZh+3Meu6ZZFcmUi8oCKTxW/wzOH2COzKacUdGs7+2r30Scr+ftLJE6gsdLJbGr2R3dO0iAaMgkMYv7dLo2kUCRkEgjd/wR6BZq7PbmrUqbxcRBYEEQOZ34ObpkNYdsPKfN08vbxcRHSyWgMj8jj74RaqgLQIRkYBTEIiIBJyCQEQk4OISBGY20sy2mdmnZvZghNdzzazEzArDD52uISJST8R8sNjMUoBngL8FdgOrzGyBu28+p+sydx8d63giIhJf8dgiGAR86u473f0UMBcYG4fliohIAsQjCLoBu854vjvcdq6rzWydmb1rZulVLczM8syswMwKiouL41CeiIhUJx5BEOkWUede23oNcJm7ZwH/AbxZ1cLcfaa757h7TqdOneJQnoiIVCceQbAb6H7G80uBPWd2cPcj7n40PP0O0MzMOsZhbBERiVE8gmAV0MvMeppZc+BOYMGZHczsEgvfXNbMBoXHPRiHsUVEJEYxnzXk7iEzmwy8B6QAL7j7JjObGH59BnAbMMnMQsAJ4E6vz7dGExEJEN2qUkSkEbmQW1Xqm8UiIgGnIBARCbgGEQRFRUVkZGQkuwyRuJs1axaTJ09OdhkScA0iCEREpO7EPQjMrIeZbTWz35rZRjN72cxGmNkKM9tuZoPM7CIze9PM1pvZSjPLDM871cxeMLMlZrZz//795y1/586d9O/fn1WrVrFjxw5GjhzJgAEDGDp0KFu3bqW0tJSePXty+vRpAI4cOUKPHj0qn4vEyy233MKAAQNIT09n5syZALRp04Z/+Zd/ITs7m+HDh1Px7fjc3Fzuu+8+rrnmGjIyMvj444/PW15xcTHf/va3GThwIAMHDmTFihUJfT9yYdydsrKyZJcRG3eP6wPoAYSAvpQHzWrgBcq/gTyW8m8V/wfw03D/G4DC8PRU4E9AC6BjSkqKnzp1yj/77DNPT0/3rVu3er9+/Xzt2rXu7n7DDTf4J5984u7uK1eu9Ouvv97d3f/hH/7B582b5+7uzz33nP/4xz92kXg7ePCgu7sfP37c09PT/cCBAw74f//3f7u7+6OPPur33nuvu7sPGzbMv//977u7+x//+EdPT093d/cXX3yxss9dd93ly5Ytc3f3zz//3Hv37p3Q9yNV+81vfuPp6emenp7u06ZN888++8x79+7tkyZN8n79+nlRUZFPnDjRBwwY4H369PFHHnmkct7LLrvMH3nkEe/fv79nZGT4li1b3N19//79PmLECO/fv7/n5eX517/+dS8uLnZ395deeskHDhzoWVlZnpeX56FQqNa1AgUe7ed2tDPUuMDyINh+xvM5wN3h6cuBQmAtcPkZfXYBaeEgeKiivWXLlr5r1y7/7LPPvHPnzn7FFVf4xo0b3d29tLTUW7Zs6VlZWZWPil+c5cuX+5gxY9zdffDgwb5hw4Zar0SR2vrpT3/qmZmZnpmZ6e3atfP8/Hxv0qSJnz592t3dd+zY4VlZWe5eHgSLFy+unLd79+5+6NChs4KgU6dOZ/1/7tq1qx85ciTh70vOVlBQ4BkZGX706FEvLS31Pn36+Jo1a9zMPD8/v7JfxR8GoVDIhw0b5uvWrXP38iCYPn26u7s/88wz/r3vfc/d3e+9917/93//d3d3f/fddx3w4uJi37x5s48ePdpPnTrl7u6TJk3y2bNn17reCwmCurpn8ZdnTJed8byM8i+xhSLMU/GFhjPnJRQq75qWlkb37t1ZsWIF6enplJWV0b59ewoLC89b0JAhQygqKuKPf/wjX331lQ40S9wtWbKERYsW
kZ+fT+vWrcnNzeXkyZPn9Qt/of686UjPy8rKyM/Pp1WrVnVTtFyQ5cuXM27cOFJTUwG49dZbWbZsGZdddhmDBw+u7Pfaa68xc+ZMQqEQe/fuZfPmzWRmZlbOAzBgwADeeOONyuXOmzcPgJEjR9KhQwcAFi9ezOrVqxk4cCAAJ06coHPnznX6HpN1sHgpcDeU37QGOODuR6qboXnz5rz55pvMmTOHV155hXbt2tGzZ09+//vfA+VbNuvWravsP378eO666y7uueeeunoPEmAlJSV06NCB1q1bs3XrVlauXAmUf5i//vrrALzyyitce+21lfO8+uqrQPkHQFpaGmlpaWct88Ybb+Tpp5+ufB7pjxxJPK/iS7cVwQDw2Wef8cQTT7B48WLWr1/PTTfddNYfBi1atAAgJSWl8o/bqpbr7kyYMIHCwkIKCwvZtm0bU6dOjdO7iSxZQTAVyDGz9cDjwITazJSamspbb73FtGnTmD9/Pi+//DLPP/88WVlZpKenM3/+/Mq+d999N4cOHeKuu+6qkzcgwTZy5EhCoRCZmZk8/PDDlX8ZpqamsmnTJgYMGMCHH37II4/89WZ8HTp04JprrmHixIk8//zz5y1z+vTpFBQUkJmZSZ8+fZgxY0bC3o9U7brrruPNN9/k+PHjHDt2jHnz5jF06NCz+hw5coTU1FTS0tLYt28f7777bo3Lvfbaa3nttdcAeP/99zl06BAAw4cP5/XXX6fiZJkvvviCzz//PM7v6hzR7ktK5GPAgAG13i9W4fCCBf7J9Tf4tK7dfMzFF/vhBQuiXobIhUpNTY3YPmzYMF+1alWV8x1ds8/3/OLPvusnS33PL/7sR9fsq6sS5QJEOlhcccC/woQJE7x3794+atQoHzdunL/44ovuXn6MoOIg8KpVq3zYsGHu7r5v3z6/4YYbvH///n7fffd5ly5d/OTJk+7uPnfuXM/KyvK+fft6dnb2WcciasIFHCNoVNcaKlm4kL0PP8LPPy9i2bFjPHdpd3q2a0eXn/+MtJtvrsNKRcq1adOGo0ePnteem5vLE088QU7O+ZeAObZ2P4ff2I6f/uspiNasCe1v7UVq/7rdNyzJ8+WXX5KSkkLTpk3Jz89n0qRJPPbGYzy15in+cuwvXJJ6CVOyp3DT5TdFtdwLudZQowqC7TcMJ7Rnz3ntTbt2pdeHi+NZmkjc7H38Y746/OV57SntW9DlwUFJqEgSYfv27XznO9+hrKyM5s2b892HvssrR1/h5Fd/PbbQMqUlU6+ZGlUYXEgQ1NVZQ0kR2rs3qnaR+iBSCFTXLo1Dr169WLt2beXzG1+/8awQADj51UmeWvNU1FsF0WpUl5ho2qVLVO0i9UFK+xZRtUvj9Jdjf4mqPZ4aVRB0/uf7sJYtz2qzli3p/M/3JacgkVpo980eWLOzfxWtWRPafbNHcgqSpLgk9ZKo2uOpUQVB2s030+XnP6Np165gRtOuXXWgWOq91P6daX9rr8otgJT2LXSgOICmZE+hZcrZf8i2TGnJlOwpdT52ozpYLCLSkL298+2knDXUqA4Wi4g0ZDddflOdHxiOpFHtGhIRkegpCEREAk5BICIScAoCEZGAUxBInVmyZAl/+tOfkl2GiNQgLkFgZiPNbJuZfWpmD0Z43cxsevj19WaWHY9xpX5TEIg0DDEHgZmlAM8A3wL6AHeZWZ9zun0L6BV+5AH/Geu4kjxz5swhMzOTrKws/v7v/56FCxdy1VVX0b9/f0aMGMG+ffsoKipixowZTJs2jX79+rFs2bJkly0iVYjH9wgGAZ+6+04AM5tL+U3qN5/RZywwJ3yt7JVm1t7Muri7rgbXwGzatInHHnuMFStW0LFjR7744gvMjJUrV2Jm/Pa3v+VXv/oVv/nNb5g4cSJt2rTh/vvvT3bZIlKNeARBN8pvPl9hN3BVLfp0A84LAjPLo3yrga9//etxKE/i6cMPP+S2226jY8eOAFx00UVs2LCBO+64g717
93Lq1Cl69uyZ5CpFJBrxOEZgEdrOvW5FbfqUN7rPdPccd8/p1KlTzMVJfLn7eTdd/6d/+icmT57Mhg0beO655yLexF1E6q94BMFuoPsZzy8Fzr07TG36SAMwfPhwXnvtNQ4ePAiU30+1pKSEbt26ATB79uzKvm3btqW0tDQpdYpI7cUjCFYBvcysp5k1B+4EFpzTZwEwPnz20GCgRMcHGqb09HQeeughhg0bRlZWFj/+8Y+ZOnUqt99+O0OHDq3cZQRw8803M2/ePB0sFqnn4nL1UTMbBTwJpAAvuPtjZjYRwN1nWPm+hKeBkcBx4B53r/Gyorr6aMOzfv16Fi9eTElJCWlpaQwfPpzMzMxklyUSGEm7+qi7vwO8c07bjDOmHbg3HmNJ/bV+/XoWLlzI6dOnASgpKWHhwoUACgORekzfLJa4Wbx4cWUIVDh9+jSLFy9OUkUiUhsKAombkpKSqNpFpH5QEEjcpKWlRdUuIvWDgkDiZvjw4TRr1uystmbNmjF8+PAkVSQitaFbVUrcVBwQ1llDIg2LgkDiKjMzUx/8Ig2Mdg2JSEIUFRWRkZGR7DIkAgWBiEjAKQhEJGFCoRATJkwgMzOT2267jePHj/Ozn/2MgQMHkpGRQV5eHhVXO8jNzeUnP/kJgwYN4hvf+EblZUqKiooYOnQo2dnZZGdnV978aMmSJeTm5nLbbbfRu3dv7r777splVTWGlFMQiEjCbNu2jby8PNavX0+7du149tlnmTx5MqtWrWLjxo2cOHGCt956q7J/KBTi448/5sknn+TRRx8FoHPnznzwwQesWbOGV199lR/96EeV/deuXcuTTz7J5s2b2blzJytWrACodgxREIhIAnXv3p0hQ4YA8Hd/93csX76cjz76iKuuuoq+ffvy4YcfsmnTpsr+t956KwADBgygqKgIKP+2+g9+8AP69u3L7bffzubNf70H1qBBg7j00ktp0qQJ/fr1q5ynujFEZw2JSAKdey8LM+OHP/whBQUFdO/enalTp551P4sWLVoAkJKSQigUAmDatGlcfPHFrFu3jrKyMlq2bHle/zPnOXnyZLVjiLYIRCSB/ud//of8/HwAfve733HttdcC0LFjR44ePcrrr79e4zJKSkro0qULTZo04aWXXuKrr76qtn/Fh340YwSNtghEJGGuvPJKZs+ezT/+4z/Sq1cvJk2axKFDh+jbty89evRg4MCBNS7jhz/8Id/+9rf5/e9/z/XXX09qamq1/du3b1+5K6m2YwRNXO5HUFd0PwIRicWWZR+xbO4cSg8eoO3XOjL0zvFcOfT6ZJdVp5J2PwIRkfpmy7KPeH/m04ROfQlA6YFi3p/5NECjD4No6RiBiDRKy+bOqQyBCqFTX7Js7pwkVVR/KQhEpFEqPXggqvYgUxCISKPU9msdo2oPMgWBiDRKQ+8cT9PmLc5qa9q8BUPvHJ+kiuovHSwWkUap4oBw0M4auhAKAhFptK4cer0++GtBu4ZERAJOQSAiEnAx7Roys4uAV4EeQBHwHXc/FKFfEVAKfAWEov3Wm4iI1J1YtwgeBBa7ey9gcfh5Va53934KARGR+iXWIBgLzA5PzwZuiXF5IiKSYLEGwcXuvhcg/LNzFf0ceN/MVptZXnULNLM8Mysws4Li4uIYyxMRkZrUeIzAzBYBl0R46aEoxhni7nvMrDPwgZltdfelkTq6+0xgJpRffTSKMURE5ALUGATuPqKq18xsn5l1cfe9ZtYF2F/FMvaEf+43s3nAICBiEIiISGLFumtoATAhPD0BmH9uBzNLNbO2FdPAjcDGGMcVEZE4iTUIHgf+1sy2A38bfo6ZdTWzd8J9LgaWm9k64GPgbXf/Q4zjiohInMT0PQJ3PwgMj9C+BxgVnt4JZMUyjoiI1B19s1hEJOAUBCIiAacgEBEJOAWBiEjAKQhERAJOQSAiEnAKAhGRgFMQiIgEnIJARCTgFAQiIgGnIBARCTgFgYhIwCkIREQCTkEg
IhJwCgIRkYBTEIiIBJyCQEQk4BQEIiIBpyAQEQk4BYGISMApCEREAk5BICIScAoCEZGAUxCIiARcTEFgZreb2SYzKzOznGr6jTSzbWb2qZk9GMuYIiISX7FuEWwEbgWWVtXBzFKAZ4BvAX2Au8ysT4zjiohInDSNZWZ33wJgZtV1GwR86u47w33nAmOBzbGMLSIi8ZGIYwTdgF1nPN8dbovIzPLMrMDMCoqLi+u8OBGRoKtxi8DMFgGXRHjpIXefX4sxIm0ueFWd3X0mMBMgJyenyn4iIhIfNQaBu4+IcYzdQPcznl8K7IlxmSIiEieJ2DW0CuhlZj3NrDlwJ7AgAeOKiEgtxHr66Dgz2w1cDbxtZu+F27ua2TsA7h4CJgPvAVuA19x9U2xli4hIvMR61tA8YF6E9j3AqDOevwO8E8tYIiJSN/TNYhGRgFMQiIgEnIJARCTgFAQiIgGnIBARCTgFgYhIwCkIREQCTkEgIhJwCgIRkYBTEIiIBJyCQEQk4BQEIiIBpyAQEQk4BYGISMApCEREAk5BICIScAoCEZGAUxCIiAScgkBEJOAUBCIiAacgEBEJOAWBiEjAKQhERAJOQSAiEnAxBYGZ3W5mm8yszMxyqulXZGYbzKzQzApiGVNEROKraYzzbwRuBZ6rRd/r3f1AjOOJiEicxRQE7r4FwMziU42IiCRcoo4ROPC+ma02s7zqOppZnpkVmFlBcXFxgsoTEQmuGrcIzGwRcEmElx5y9/m1HGeIu+8xs87AB2a21d2XRuro7jOBmQA5OTley+WLiMgFqjEI3H1ErIO4+57wz/1mNg8YBEQMAhERSaw63zVkZqlm1rZiGriR8oPMIiJSD8R6+ug4M9sNXA28bWbvhdu7mtk74W4XA8vNbB3wMfC2u/8hlnFFRCR+Yj1raB4wL0L7HmBUeHonkBXLOCIiUnf0zWIRkYBTEIiIBJyCQEQk4BQEIvXM9OnTufLKK7n77rtrPc+oUaM4fPgwhw8f5tlnn63D6qQxMvf6+52tnJwcLyjQNeokWHr37s27775Lz549K9tCoRBNm9Z8bkdRURGjR49m40adoR1UZrba3au8CGgk2iIQqUcmTpzIzp07GTNmDGlpaeTl5XHjjTcyfvx4Zs2axeTJkyv7jh49miVLlgDQo0cPDhw4wIMPPsiOHTvo168fDzzwQJLehTQ0sV59VETiaMaMGfzhD3/go48+4umnn2bhwoUsX76cVq1aMWvWrBrnf/zxx9m4cSOFhYV1Xqs0HtoiEKnHxowZQ6tWrZJdhjRyCgKReiw1NbVyumnTppSVlVU+P3nyZDJKkkZIQSDSQPTo0YPCwkLKysrYtWsXH3/88Xl92rZtS2lpaRKqk4ZMQSDSQAwZMoSePXvSt29f7r//frKzs8/r87WvfY0hQ4aQkZGhg8VSazp9VKQR+OTPfyF//g6OfvElbS5qwdVj/4ZvXBXpNiLS2F3I6aM6a0ikgfvkz3/ho5e3EjpVfvzg6Bdf8tHLWwEUBlIr2jUk0sDlz99RGQIVQqfKyJ+/I0kVSUOjIBBp4I5+8WVU7SLnUhCINHBtLmoRVbvIuRQEIg3c1WP/hqbNz/5Vbtq8CVeP/ZskVSQNjQ4WizRwFQeEddaQXCgFgUgj8I2rLtEHv1ww7RoSEQk4BYGISMApCEREAk5BICIScAoCEZGAq9cXnTOzYuDzJA3fETiQpLGrUh9rAtUVjfpYE6iuaNTHmuCvdV3m7p2imbFeB0EymVlBtFfwq2v1sSZQXdGojzWB6opGfawJYqtLu4ZERAJOQSAiEnAKgqrNTHYBEdTHmkB1RaM+1gSqKxr1sSaIoS4dIxARCThtEYiIBJyCQEQk4BQEYWb2azPbambrzWyembWvot9IM9tmZp+a2YN1XNPtZrbJzMrMrMrTwsysyMw2mFmhmRXUZU1R1pWwdRUe7yIz+8DMtod/dqiiX52vr5reu5WbHn59vZll10UdF1BXrpmV
hNdNoZk9koCaXjCz/Wa2sYrXE76ualFTwtdTeNzuZvaRmW0J/w5OidAn+vXl7nqUHye5EWganv4l8MsIfVKAHcDlQHNgHdCnDmu6ErgCWALkVNOvCOiYwHVVY12JXlfhMX8FPBiefjDSv2Ei1ldt3jswCngXMGAw8OcE/LvVpq5c4K1E/V8Kj3kdkA1srOL1ZKyrmmpK+HoKj9sFyA5PtwU+icf/LW0RhLn7++4eCj9dCVwaodsg4FN33+nup4C5wNg6rGmLu2+rq+VfqFrWldB1FTYWmB2eng3cUsfjVaU2730sMMfLrQTam1mXelBXwrn7UuCLarokfF3VoqakcPe97r4mPF0KbAG6ndMt6vWlIIjs/1CeqOfqBuw64/luzv9HSAYH3jez1WaWl+xiwpKxri52971Q/gsDdK6iX12vr9q892Ssn9qOebWZrTOzd80svY5rqo36+nuX1PVkZj2A/sCfz3kp6vUVqDuUmdkiINJtnB5y9/nhPg8BIeDlSIuI0BbT+be1qakWhrj7HjPrDHxgZlvDf9Eks664ryuovq4oFhP39XWO2rz3Olk/NajNmGsov1bNUTMbBbwJ9KrjumqSjHVVk6SuJzNrA/w/4D53P3LuyxFmqXZ9BSoI3H1Eda+b2QRgNDDcwzvbzrEb6H7G80uBPXVZUy2XsSf8c7+ZzaN8F0BMH2xxqCvu6wqqr8vM9plZF3ffG94U3l/FMuK+vs5Rm/deJ+sn1rrO/FBx93fM7Fkz6+juybzIWjLWVbWSuZ7MrBnlIfCyu78RoUvU60u7hsLMbCTwE2CMux+votsqoJeZ9TSz5sCdwIJE1RiJmaWaWduKacoPekc80yHBkrGuFgATwtMTgPO2XBK0vmrz3hcA48NneAwGSip2a9WhGusys0vMzMLTgyj/jDhYx3XVJBnrqlrJWk/hMZ8Htrj7/62iW/TrK9FHvevrA/iU8v1qheHHjHB7V+CdM/qNovxI/Q7Kd5PUZU3jKE/3L4F9wHvn1kT5GSDrwo9NdV1TbetK9LoKj/c1YDGwPfzzomStr0jvHZgITAxPG/BM+PUNVHNWWILrmhxeL+soP2nimgTU9DtgL3A6/P/qe8leV7WoKeHrKTzutZTv5ll/xmfVqFjXly4xISIScNo1JCIScAoCEZGAUxCIiAScgkBEJOAUBCIiAacgEBEJOAWBiEjA/X98Bvj7/Zm/XwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "batch_size = 2  # mini-batch size\n",
    "embedding_size = 10  # embedding size\n",
    "\n",
    "sentences = [\"apple banana fruit\", \"banana orange fruit\", \"orange banana fruit\",\n",
    "             \"dog cat animal\", \"cat monkey animal\", \"monkey dog animal\"]\n",
    "\n",
    "word_sequence = \" \".join(sentences).split()\n",
    "# Sort the vocabulary so the word -> index mapping is the same on every kernel run\n",
    "# (the iteration order of a set of str changes with per-process hash randomization).\n",
    "word_list = sorted(set(word_sequence))\n",
    "word_dict = {w: i for i, w in enumerate(word_list)}\n",
    "voc_size = len(word_list)\n",
    "\n",
    "# Make skip-gram pairs [target, context] with a window of one word on each side.\n",
    "# Windows stay inside each sentence (as in gensim): the last word of one sentence is\n",
    "# never paired with the first word of the next (e.g. \"fruit\" / \"dog\"), and words at\n",
    "# sentence edges such as \"apple\" are also used as targets.\n",
    "skip_grams = []\n",
    "for sentence in sentences:\n",
    "    words = sentence.split()\n",
    "    for i, word in enumerate(words):\n",
    "        target = word_dict[word]\n",
    "        for j in (i - 1, i + 1):\n",
    "            if 0 <= j < len(words):\n",
    "                skip_grams.append([target, word_dict[words[j]]])\n",
    "\n",
    "model = Word2Vec()\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Training\n",
    "for epoch in range(9000):\n",
    "    input_batch, target_batch = random_batch()\n",
    "    input_batch = torch.Tensor(input_batch)\n",
    "    target_batch = torch.LongTensor(target_batch)\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    output = model(input_batch)\n",
    "\n",
    "    # output : [batch_size, voc_size], target_batch : [batch_size] (LongTensor, not one-hot)\n",
    "    loss = criterion(output, target_batch)\n",
    "    if (epoch + 1) % 1000 == 0:\n",
    "        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "# Plot each word at the first two components of its learned input embedding.\n",
    "# NOTE(review): assumes Word2Vec has exactly two parameters (W, WT) and that\n",
    "# W[d][i] is component d of word i -- confirm against the class definition.\n",
    "# With embedding_size = 10 this is a 2-D slice of each vector, not a projection.\n",
    "W, WT = model.parameters()  # loop-invariant: fetch once\n",
    "fig, ax = plt.subplots()\n",
    "for i, label in enumerate(word_list):\n",
    "    x, y = W[0][i].item(), W[1][i].item()\n",
    "    ax.scatter(x, y)\n",
    "    ax.annotate(label, xy=(x, y), xytext=(5, 2),\n",
    "                textcoords='offset points', ha='right', va='bottom')\n",
    "ax.set(title='Skip-gram word embeddings (first two dimensions)', xlabel='dim 0', ylabel='dim 1')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ce015dc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Remove the saved model files. A file that does not exist is skipped, so re-running\n",
    "# this cell (or running it without the cells that create the files) does not raise.\n",
    "for path in ('word2vec.model', 'word2vec.wordvectors', 'C000008.word2vec.model'):\n",
    "    try:\n",
    "        os.remove(path)\n",
    "    except FileNotFoundError:\n",
    "        pass  # already deleted\n",
    "\n",
    "print('cleanup done')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1798e937",
   "metadata": {},
   "source": [
    "本节完。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ded1878",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}